## Internal helper for mediate.nnet(): fit one mediator model. Continuous
## mediators are fit by OLS (lm), binary mediators by a logit GLM.
## `data` must already carry a `weights` column; lm()/glm() evaluate their
## `weights` argument inside `data` first and therefore pick that column up.
.mediate_nnet_fit_mediator <- function(formula, type, data) {
  if (type == "continuous") {
    lm(formula, data = data, weights = weights)
  } else {
    glm(formula, data = data, weights = weights,
        family = binomial(link = "logit"))
  }
}

## Internal helper for mediate.nnet(): fit the multinomial outcome model and
## every supplied mediator model on one data set. Returns a list with the
## outcome fit (`y`), a length-3 list of mediator fits (`m`, NULL where the
## mediator is absent) and the observation weights that were used (`w`).
.mediate_nnet_fit_all <- function(formula.y, med.formulas, med.types, meds,
                                  data, weights) {
  ## Always provide a `weights` column: multinom()/lm()/glm() look their
  ## `weights` argument up in `data` first and otherwise fall back to the
  ## formula environment, where they typically find the function
  ## stats::weights() and fail. Unit weights leave the estimates unchanged.
  data$weights <- if (is.null(weights)) {
    rep(1, nrow(data))
  } else {
    data[[weights]]
  }
  mod.y <- multinom(formula.y,
                    Hess = TRUE,
                    data = data,
                    weights = weights,
                    maxit = 500L)
  mods <- vector("list", length(med.formulas))
  for (k in meds) {
    mods[[k]] <- .mediate_nnet_fit_mediator(med.formulas[[k]],
                                            med.types[k],
                                            data)
  }
  list(y = mod.y, m = mods, w = data$weights)
}

## Internal helper for mediate.nnet(): `S` draws (rows) from the normal
## approximation N(coefs, vc). Always returns an S x length(coefs) matrix,
## because MASS::mvrnorm() drops to a plain vector when S == 1.
.mediate_nnet_draw <- function(S, coefs, vc) {
  matrix(mvrnorm(S, coefs, vc), nrow = S)
}

## Internal helper for mediate.nnet(): multinomial-logit category
## probabilities (J x N) for one coefficient matrix `beta` (J x K, reference
## row all zeros) and a design matrix `X` (N x K).
.mediate_nnet_probs <- function(beta, X) {
  expo <- exp(tcrossprod(beta, X))
  sweep(expo, 2L, colSums(expo), "/")
}

## Internal helper for mediate.nnet(): for every simulation draw, the
## weighted mean over observations of the change in outcome-category
## probabilities between scenario 1 and scenario 2, divided by
## `effect.scale`. `design(s, scen)` returns the N x K design matrix for
## draw `s` and scenario `scen` (1 or 2). Returns a J x n.sim matrix.
.mediate_nnet_effect <- function(b.sim, J, design, w, effect.scale) {
  n.sim <- nrow(b.sim)
  out <- matrix(NA_real_, nrow = J, ncol = n.sim)
  for (s in seq_len(n.sim)) {
    ## b.sim stores the categories' coefficient blocks side by side
    ## (reference block first), so fill the J x K matrix row-wise.
    beta <- matrix(b.sim[s, ], nrow = J, byrow = TRUE)
    delta <- .mediate_nnet_probs(beta, design(s, 2L)) -
      .mediate_nnet_probs(beta, design(s, 1L))
    out[, s] <- apply(delta / effect.scale, 1L, weighted.mean, w)
  }
  out
}

## Simulation-based mediation analysis for a multinomial (nnet::multinom)
## outcome with up to three mediators.
##
## Fits the outcome model and one lm()/logit glm() per mediator (on every
## imputed data set when `imp = TRUE`), draws `S` parameter vectors per data
## set from the estimated normal sampling distributions, and summarises the
## simulated changes in outcome-category probabilities.
##
## Arguments
##   formula.y       Outcome formula for multinom(). Mediator columns of its
##                   design matrix are located by matching the mediator
##                   names with grepl().
##   formula.m1      Formula of the first mediator model (required).
##   formula.m2/m3   Optional further mediator formulas (NULL = absent).
##   type.m1..m3     "continuous" (lm) or "binary" (logit glm).
##   data            A data frame, or a list of imputed data frames when
##                   `imp = TRUE`.
##   weights         NULL, or the name of a weights column in `data`.
##   imp             Whether `data` is a list of imputed data sets.
##   S               Number of parameter draws per data set.
##   seed            Passed to set.seed(); the global RNG state is changed.
##   reorder.y       Currently unused; kept for backward compatibility.
##   x.name          Regular expression (grepl()) identifying the treatment
##                   column(s) in the design matrices.
##   x.shift         "unit" or "sd": size of the shift of a continuous
##                   treatment (x - shift/2 versus x + shift/2).
##
## Value
##   A list with `te`, `de`, `me1`, `me2`, `me3` (3 x J matrices holding the
##   median and the 2.5%/97.5% simulation quantiles per outcome category;
##   for a continuous treatment the effects are per unit of the shift) and
##   `m1.ame`, `m2.ame`, `m3.ame` (quantiles of the average change in each
##   mediator between the two treatment scenarios). Components belonging to
##   absent mediators are NULL.
mediate.nnet <- function(formula.y,
                         formula.m1,
                         formula.m2 = NULL,
                         formula.m3 = NULL,
                         type.m1 = c("continuous", "binary"),
                         type.m2 = c("continuous", "binary"),
                         type.m3 = c("continuous", "binary"),
                         data,
                         weights = NULL,
                         imp = FALSE,
                         S = 1000L,
                         seed = 19890213,
                         reorder.y = FALSE,
                         x.name,
                         x.shift = c("unit", "sd")) {
  ## The enumerated arguments default to length-2 vectors; resolve them so
  ## the scalar comparisons below are valid.
  type.m1 <- match.arg(type.m1)
  type.m2 <- match.arg(type.m2)
  type.m3 <- match.arg(type.m3)
  x.shift <- match.arg(x.shift)

  ## Seed
  set.seed(seed)

  ## Supplied mediators. Positions 1-3 are kept so that results stay aligned
  ## with me1/me2/me3 even when, e.g., only formula.m3 is given.
  med.formulas <- list(formula.m1, formula.m2, formula.m3)
  med.types <- c(type.m1, type.m2, type.m3)
  meds <- which(!vapply(med.formulas, is.null, logical(1L)))

  data.list <- if (imp) data else list(data)
  M <- length(data.list)
  if (M == 0L) {
    stop("`data` contains no data sets", call. = FALSE)
  }

  ## Model Estimation
  message("1 Estimating models")
  fits <- lapply(data.list, function(d) {
    .mediate_nnet_fit_all(formula.y, med.formulas, med.types, meds,
                          d, weights)
  })

  ## Extract model information (from the last data set, as the names and
  ## terms are identical across imputations)
  last <- fits[[M]]
  mod.y <- last$y
  ## `lab` is only filled for matrix (one-hot) responses; a two-level factor
  ## outcome only records `lev`.
  j.names <- if (!is.null(mod.y$lab)) mod.y$lab else mod.y$lev
  coef.y <- coef(mod.y)
  if (is.null(dim(coef.y))) {
    ## binary outcome: multinom returns a plain coefficient vector
    coef.y <- matrix(coef.y, nrow = 1L)
  }
  J <- nrow(coef.y) + 1L
  K <- ncol(coef.y)
  w <- last$w

  ## Parameter Simulation. Draw order per data set (outcome, then mediators
  ## 1-3) is kept so results are reproducible for a given seed.
  message("2 Simulating Model Parameters")
  b.parts <- vector("list", M)
  m.parts <- rep(list(vector("list", M)), 3L)
  for (i in seq_len(M)) {
    fit <- fits[[i]]
    coef.vec <- as.vector(t(coef(fit$y)))
    ## Prepend K zero columns: the reference category's coefficients.
    b.parts[[i]] <- cbind(matrix(0, S, K),
                          .mediate_nnet_draw(S, coef.vec, vcov(fit$y)))
    for (k in meds) {
      m.parts[[k]][[i]] <- .mediate_nnet_draw(S,
                                              coef(fit$m[[k]]),
                                              vcov(fit$m[[k]]))
    }
  }
  b.sim <- do.call(rbind, b.parts)
  m.sim <- lapply(m.parts, function(p) do.call(rbind, p))

  ## Data matrices (one per data set), built from the fitted models' terms
  terms.y <- delete.response(terms(mod.y))
  X.list <- lapply(data.list, function(d) model.matrix(terms.y, data = d))
  Z.list <- vector("list", 3L)
  m.names <- character(3L)
  for (k in meds) {
    mod.k <- last$m[[k]]
    m.names[k] <- as.character(terms(mod.k))[2L]
    terms.k <- delete.response(terms(mod.k))
    Z.list[[k]] <- lapply(data.list,
                          function(d) model.matrix(terms.k, data = d))
  }

  ## Find X/Z*-variables of interest (column names agree across data sets)
  X1 <- X.list[[1L]]
  x.index <- grepl(x.name, colnames(X1))
  if (!any(x.index)) {
    stop(sprintf("`x.name` (\"%s\") matches no column of the outcome model",
                 x.name), call. = FALSE)
  }
  x.which <- which(x.index)
  x.which.m <- vector("list", 3L)
  z.which <- vector("list", 3L)
  for (k in meds) {
    x.which.m[[k]] <- which(grepl(m.names[k], colnames(X1)))
    z.which[[k]] <- which(grepl(x.name, colnames(Z.list[[k]][[1L]])))
  }
  x.type <- if (sum(x.index) > 1L) {
    "cat"
  } else if (length(unique(X1[, x.index])) > 2L) {
    "cont"
  } else {
    "bin"
  }

  ## Average X/Z* over imputations
  average <- function(mats) Reduce(`+`, mats) / length(mats)
  X <- average(X.list)
  Z <- lapply(Z.list, function(z) if (is.null(z)) NULL else average(z))
  N <- nrow(X)
  if (length(w) != N) {
    stop("weights do not match the rows of the model matrix ",
         "(missing values in `data`?)", call. = FALSE)
  }

  ## Predictions
  message("3 Calculating requested quantities")
  ## Define Shifts
  shift <- if (x.shift == "unit") 1 else sd(X[, x.which])
  ## Binary/categorical contrasts are a full 0 -> 1 switch, so only
  ## continuous-treatment effects are expressed per unit of the shift.
  effect.scale <- if (x.type == "cont") abs(shift) else 1

  ## Set the treatment column(s) `cols` of `mat` to scenario `scen` (1 or 2):
  ## x -/+ shift/2 for a continuous x, otherwise 0 / 1.
  ## NOTE(review): for a categorical x every dummy is set to the same value,
  ## as in the original implementation -- confirm this is intended.
  set.x <- function(mat, cols, scen) {
    mat[, cols] <- if (x.type == "cont") {
      mat[, cols] + (scen - 1.5) * shift
    } else {
      scen - 1
    }
    mat
  }

  ## Counterfactual predictions of the mediators
  ## (scenario 1, scenario 2, and 3 = observed treatment)
  m.hat <- vector("list", 3L)
  m.ame <- vector("list", 3L)
  for (k in meds) {
    hat <- array(NA_real_, c(M * S, N, 3L))
    for (scen in 1:3) {
      Z.scen <- if (scen < 3L) set.x(Z[[k]], z.which[[k]], scen) else Z[[k]]
      lp <- tcrossprod(m.sim[[k]], Z.scen)
      hat[, , scen] <- if (med.types[k] == "binary") plogis(lp) else lp
    }
    m.hat[[k]] <- hat
    change <- matrix(hat[, , 2L] - hat[, , 1L], nrow = M * S)
    m.ame[[k]] <- quantile(apply(change, 1L, weighted.mean, w),
                           c(.5, .025, .975))
  }

  ## Plug simulated mediator values into `mat` for draw `s`: mediator
  ## `target` takes scenario `scen`, all others their observed-x prediction.
  fill.meds <- function(mat, s, target = 0L, scen = 3L) {
    for (med in meds) {
      mat[, x.which.m[[med]]] <-
        m.hat[[med]][s, , if (med == target) scen else 3L]
    }
    mat
  }

  ## Direct effects: shift x, hold mediators at observed-x predictions
  X.de <- list(set.x(X, x.which, 1L), set.x(X, x.which, 2L))
  de.sim <- .mediate_nnet_effect(
    b.sim, J, function(s, scen) fill.meds(X.de[[scen]], s), w, effect.scale
  )

  ## Indirect effects: x at observed values, shift one mediator at a time
  me.sim <- vector("list", 3L)
  for (k in meds) {
    me.sim[[k]] <- .mediate_nnet_effect(
      b.sim, J, function(s, scen) fill.meds(X, s, k, scen), w, effect.scale
    )
  }

  ## Total effects
  te.sim <- de.sim
  for (k in meds) {
    te.sim <- te.sim + me.sim[[k]]
  }

  ## Labeling
  summarise.sim <- function(sim) {
    out <- apply(sim, 1L, quantile, c(.5, .025, .975))
    dimnames(out) <- list(c("Est.", "2.5%", "97.5%"), j.names)
    out
  }
  me <- lapply(me.sim, function(sim) {
    if (is.null(sim)) NULL else summarise.sim(sim)
  })

  ## Value
  message("4 Returning Output")
  list(te = summarise.sim(te.sim),
       de = summarise.sim(de.sim),
       me1 = me[[1L]],
       me2 = me[[2L]],
       me3 = me[[3L]],
       m1.ame = m.ame[[1L]],
       m2.ame = m.ame[[2L]],
       m3.ame = m.ame[[3L]])
}